{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dfccd8e6",
   "metadata": {},
   "source": [
    "# LLM Encrypted Token Generation\n",
    "\n",
    "This notebook shows how to configure a GPT-2 model to generate text from an encrypted prompt. The GPT-2 model shown\n",
    "here runs some layers on the client machine and the others on the server using Fully Homomorphic Encryption (FHE):\n",
    "- on the client-side: non-linear layers, such as attention, normalization and activation functions\n",
    "- on the server-side: all linear layers that have trained weights\n",
    "\n",
    "To generate one token, the model shown here requires around 11 seconds on a desktop GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "eca73e44",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x73526e76a630>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, Conv1D\n",
    "\n",
    "from concrete.ml.torch.hybrid_model import HybridFHEModel\n",
    "\n",
    "# Set random seed for reproducibility\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "c082411e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_print(prompt, model, tokenizer, seed=None, max_new_tokens=30):\n",
    "    \"\"\"\n",
    "    Generates text based on the provided prompt and prints both the prompt and the generated text.\n",
    "\n",
    "    Args:\n",
    "        prompt (str): The input prompt to generate text from.\n",
    "        model: The pre-trained language model.\n",
    "        tokenizer: The tokenizer associated with the model.\n",
    "        seed (int, optional): Seed for random number generators to ensure reproducibility.\n",
    "        max_new_tokens (int, optional): Maximum number of tokens to generate. Defaults to 30.\n",
    "    Returns:\n",
    "        str: The generated text (response only, without the prompt).\n",
    "    \"\"\"\n",
    "    # Set the environment variable for cuBLAS deterministic behavior\n",
    "    os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n",
    "\n",
    "    # Set the random seed for reproducibility\n",
    "    if seed is not None:\n",
    "        random.seed(seed)\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.manual_seed_all(seed)\n",
    "\n",
    "    # Encode the input prompt\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\")\n",
    "\n",
    "    # Generate text\n",
    "    with torch.no_grad():\n",
    "        output = model.generate(\n",
    "            input_ids=inputs[\"input_ids\"],\n",
    "            attention_mask=inputs[\"attention_mask\"],\n",
    "            max_new_tokens=max_new_tokens,\n",
    "            top_p=0.9,\n",
    "            temperature=0.6,\n",
    "            do_sample=True,\n",
    "            pad_token_id=tokenizer.eos_token_id,\n",
    "        )\n",
    "\n",
    "    # Get only the newly generated tokens\n",
    "    input_length = inputs[\"input_ids\"].shape[1]\n",
    "    generated_ids = output[0, input_length:]\n",
    "    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()\n",
    "\n",
    "    # Print the prompt and generated text\n",
    "    print(f\"Prompt: {prompt}\")\n",
    "    print(f\"Response: {generated_text}\\n\")\n",
    "\n",
    "    return generated_text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "794d9a40",
   "metadata": {},
   "source": [
    "## Load the model for inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "8b965a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pre-trained GPT-2 model and tokenizer\n",
    "model_name = \"gpt2\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "\n",
    "# Ensure tokenizer has a pad token\n",
    "if tokenizer.pad_token is None:\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "# Freeze model weights\n",
    "for param in model.parameters():\n",
    "    param.requires_grad = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d798bd0",
   "metadata": {},
   "source": [
    "## Generate some tokens from a clear-text prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "2337a6b4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Programming is\n",
      "Response: a skill you need to learn to master.\n",
      "\n",
      "Learn to code\n",
      "\n",
      "There are a lot of different ways to learn programming.\n",
      "\n",
      "The\n",
      "\n"
     ]
    }
   ],
   "source": [
    "_ = generate_and_print(prompt=\"Programming is\", model=model, tokenizer=tokenizer, seed=SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90ed3cbf",
   "metadata": {},
   "source": [
    "## Configure the layers that execute remotely\n",
    "\n",
    "All modules performing linear transformations in the GPT-2 model, namely `torch.nn.Linear` and Hugging Face `Conv1D` layers, can be identified by their Python type.\n",
    "\n",
    "In practice, the `Conv1D` layer used by GPT-2 behaves like a fully connected layer with a transposed weight matrix, which is equivalent to a 1D convolution with a kernel size of 1. It is therefore executed as a matrix multiplication, just like a standard Linear layer.\n",
    "\n",
    "The Concrete ML Extensions backend supports executing such layers as encrypted-clear matrix multiplications, enabling secure and fast inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a138d226",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the names of all linear modules (torch.nn.Linear and Hugging Face Conv1D)\n",
    "remote_names = []\n",
    "for name, module in model.named_modules():\n",
    "    if isinstance(module, (torch.nn.Linear, Conv1D)):\n",
    "        remote_names.append(name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ae2094a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the HybridFHEModel with the specified remote modules\n",
    "hybrid_model = HybridFHEModel(model, module_names=remote_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b81ea79f",
   "metadata": {},
   "source": [
    "## Compile the model\n",
    "\n",
    "This step determines the quantization parameters needed for inference. Dynamic quantization\n",
    "is enabled, so the quantization parameters are customized for each token of the sequence."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "20dfe2d8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b36294657bb457687a13d7b043298e5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Compiling FHE layers:   0%|          | 0/49 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Length (in tokens) of each calibration sequence\n",
    "BLOCK_SIZE = 32\n",
    "# Prepare input data for calibration: 256 sequences of random token IDs\n",
    "input_tensor = torch.randint(0, tokenizer.vocab_size, (256, BLOCK_SIZE), dtype=torch.long)\n",
    "\n",
    "# Calibrate and compile the model\n",
    "hybrid_model.compile_model(input_tensor, n_bits=8, use_dynamic_quantization=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65d448c8",
   "metadata": {},
   "source": [
    "## Generate a few encrypted tokens based on an encrypted prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3e91ad0b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prompt: Programming is\n",
      "Response: one of those things that, the first time around, the development team, the full-time team, would, in theory, acknowledge a\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Set FHE mode to \"execute\" so that the selected linear layers run in FHE on encrypted data\n",
    "hybrid_model.set_fhe_mode(\"execute\")\n",
    "\n",
    "_ = generate_and_print(\n",
    "    prompt=\"Programming is\", model=hybrid_model.model, tokenizer=tokenizer, seed=SEED\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3ab85824",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "This notebook shows how, with just a few lines of code, a GPT-2 model can be converted to run on encrypted data."
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
